from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import pprint
import shutil
import signal

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter


import _init_paths
from lib.core.config import config
from lib.core.config import update_config
from lib.core.config import get_model_name
from lib.core.loss import JointsMSELoss
from lib.core.function import train
from lib.core.function import validate
from lib.utils.utils import get_optimizer
from lib.utils.utils import save_checkpoint
from lib.utils.utils import create_logger
from lib.utils.utils import get_training_loader, get_training_set

import torch.distributed as dist
import torch.utils.data.distributed
import torch.nn.parallel
import torch.multiprocessing as mp

import lib.dataset as dataset
import lib.models as models

def init_seeds(seed=0):
    """Seed PyTorch's random number generators and pin cuDNN behaviour.

    Args:
        seed: Integer seed handed to each of the torch seeding functions.
    """
    seeders = (
        torch.manual_seed,           # default generator
        torch.cuda.manual_seed,      # current GPU
        torch.cuda.manual_seed_all,  # all GPUs
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def signal_handler(sig, frame):
    """Abort the current process when a registered signal is delivered.

    The exception is raised explicitly rather than through an ``assert``
    statement, because ``assert`` is removed when Python runs with ``-O`` and
    the handler would then silently do nothing.

    Args:
        sig: Signal number passed in by the ``signal`` module (unused).
        frame: Interrupted stack frame passed in by the ``signal`` module
            (unused).

    Raises:
        AssertionError: Always.
    """
    raise AssertionError('This process is killed manually!')


def parse_args():
    """Parse the command line in two passes and return the final namespace.

    The first pass only knows ``--cfg`` and ``--port``; its ``cfg`` value is
    passed to ``update_config`` before the remaining options are declared, so
    the default of ``--frequent`` is read from ``config`` after that call.

    Returns:
        The ``argparse.Namespace`` produced by the second, full parse, with
        ``cfg``, ``port``, ``frequent``, ``gpus``, ``workers`` and
        ``iteration`` attributes.
    """
    cli = argparse.ArgumentParser(description='Train keypoints network')

    # general
    cli.add_argument('--cfg',
                     type=str,
                     required=True,
                     help='experiment configure file name')
    cli.add_argument('--port', type=int, default=12590)

    # First pass: options that are not declared yet are tolerated here.
    first_pass, _unparsed = cli.parse_known_args()
    # update config
    update_config(first_pass.cfg)

    # training
    cli.add_argument('--frequent',
                     type=int,
                     default=config.PRINT_FREQ,
                     help='frequency of logging')
    cli.add_argument('--gpus', type=str, help='gpus')
    cli.add_argument('--workers',
                     type=int,
                     help='num of dataloader workers')
    cli.add_argument('--iteration',
                     type=int,
                     default=1,
                     choices=range(1, 10),
                     help='the kth times of training')

    # Second pass: strict parse with every option declared.
    return cli.parse_args()


def reset_config(config, args):
    """Override configuration values with options given on the command line.

    Args:
        config: Configuration object; its ``GPUS`` and ``WORKERS`` attributes
            are overwritten in place when the matching option was supplied.
        args: Parsed arguments exposing ``gpus`` and ``workers`` attributes.
    """
    # An empty --gpus string is treated the same as "not given".
    if args.gpus:
        config.GPUS = args.gpus
    # Compare against None so that an explicit ``--workers 0`` is honoured
    # instead of being discarded as falsy.
    if args.workers is not None:
        config.WORKERS = args.workers


def main():
    """Parse options, apply global settings and spawn one worker per GPU.

    Command-line overrides are merged into ``config``, the torch generators
    are seeded, the cuDNN flags are taken from ``config.CUDNN``, and then
    ``main_worker`` is launched once for every id listed in ``config.GPUS``.
    """
    args = parse_args()
    reset_config(config, args)
    init_seeds(config.SEED)

    # cudnn related setting (applied after init_seeds, so config wins)
    cudnn_cfg = config.CUDNN
    cudnn.benchmark = cudnn_cfg.BENCHMARK
    torch.backends.cudnn.deterministic = cudnn_cfg.DETERMINISTIC
    torch.backends.cudnn.enabled = cudnn_cfg.ENABLED

    gpu_ids = list(map(int, config.GPUS.split(',')))
    num_gpus = len(gpu_ids)
    available = torch.cuda.device_count()
    assert num_gpus <= available, \
        'available GPUS: {}, designate GPUs: {}'.format(available, num_gpus)

    print('=> initializing multiple processes')
    worker_args = (args, config, num_gpus)
    mp.spawn(main_worker, nprocs=num_gpus, args=worker_args)


def main_worker(rank, args, config, num_gpus):
    """Run the training loop of one spawned process bound to one GPU.

    Joins the gloo process group, builds the model, loss, optimizer and data
    loaders, then trains from ``config.TRAIN.BEGIN_EPOCH`` up to
    ``config.TRAIN.END_EPOCH``. Only rank 0 creates the logger, copies the
    model source file, runs validation and writes checkpoints.

    Args:
        rank: Index of this process within the group; also selects this
            process's entry from the comma-separated ``config.GPUS``.
        args: Parsed command-line namespace (``cfg``, ``port`` and
            ``iteration`` are read here).
        config: Project configuration object.
        num_gpus: Number of spawned processes, used as the world size.
    """
    os.environ['MASTER_ADDR'] = '0.0.0.0'
    os.environ['MASTER_PORT'] = str(args.port)
    dist.init_process_group(backend='gloo', rank=rank, world_size=num_gpus)
    # Seed inside the worker process as well; main() only seeded the launcher.
    init_seeds(config.SEED)
    print('Rank: {} finished initializing, PID: {}'.format(rank, os.getpid()))

    if rank == 0:
        logger, final_output_dir, tb_log_dir = create_logger(
            config, args.cfg, 'train')
        logger.info(pprint.pformat(args))
        logger.info(pprint.pformat(config))
    else:
        # Non-zero ranks have no logger and no output directories.
        final_output_dir = None
        tb_log_dir = None

    # Gracefully kill all subprocesses by command <'kill subprocess 0'>
    signal.signal(signal.SIGTERM, signal_handler)
    if rank == 0:
        logger.info('Rank {} has registerred signal handler'.format(rank))
        # Keep a copy of the model definition next to the training outputs.
        this_dir = os.path.dirname(__file__)
        shutil.copy2(
            os.path.join(this_dir, '../lib/models', config.MODEL.NAME + '.py'),
            final_output_dir)

    # Device used by the current process.
    gpus = [int(i) for i in config.GPUS.split(',')]
    device = torch.device('cuda', gpus[rank])

    # NOTE(review): eval() resolves the model factory from a config string;
    # only use configuration files from trusted sources.
    model = eval('models.' + config.MODEL.NAME + '.get_pose_net')(
        config, is_train=True
    ).to(device)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # output_device must name the same physical GPU as device_ids; using the
    # bare rank is only correct when config.GPUS happens to be 0..n-1.
    model = torch.nn.DataParallel(
        model, device_ids=[gpus[rank]], output_device=gpus[rank])
    dist.barrier()
    # resume
    if config.TRAIN.RESUME:
        assert config.TRAIN.RESUME_PATH != '', 'You must designate a path for config.TRAIN.RESUME_PATH, rank: {}'.format(
            rank)
        if rank == 0:
            logger.info('=> loading model from {}'.format(config.TRAIN.RESUME_PATH))
        # !!! map_location must be cpu, otherwise a lot of memory will be allocated on gpu:0.
        check_point = torch.load(config.TRAIN.RESUME_PATH, map_location=torch.device('cpu'))
        # Accept both a full checkpoint dict and a bare state dict.
        if 'state_dict' in check_point:
            model.load_state_dict(check_point['state_dict'], strict=False)
        else:
            model.load_state_dict(check_point, strict=False)

    # Training on a server cluster: resumed from BEGIN_EPOCH when interrupted.
    start_epoch = config.TRAIN.BEGIN_EPOCH

    # define loss function (criterion) and optimizer
    # Place the criterion on this process's GPU rather than the default
    # CUDA device.
    criterion = JointsMSELoss(
        use_target_weight=config.LOSS.USE_TARGET_WEIGHT
    ).to(device)
    optimizer = get_optimizer(config, model)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR
    )

    # Data loading code
    if rank == 0:
        logger.info('=> loading dataset')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = get_training_set(config, transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]))
    # NOTE(review): eval() resolves the dataset class from a config string;
    # only use configuration files from trusted sources.
    valid_dataset = eval('dataset.' + config.DATASET.TEST_DATASET[0].DATASET)(
        config,
        config.DATASET.TEST_DATASET[0],
        False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )

    train_loader, train_sampler = get_training_loader(train_dataset, config)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE,
        shuffle=False,
        # Split the configured worker budget across the processes.
        num_workers=int(config.WORKERS / num_gpus),
        pin_memory=True
    )

    best_perf = 0.0
    best_model = False
    dist.barrier()

    for epoch in range(start_epoch, config.TRAIN.END_EPOCH):

        train_sampler.set_epoch(epoch)

        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, writer_dict, rank, gpus[rank])

        if rank == 0:
            # evaluate on validation set
            perf_indicator = validate(config, valid_loader, valid_dataset, model,
                                      criterion, final_output_dir,
                                      writer_dict, rank, gpus[rank])
            if perf_indicator > best_perf:
                best_perf = perf_indicator
                best_model = True
            else:
                best_model = False

            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            logger.info('=> best perf_indicator is {}'.format(best_perf))
            save_checkpoint({
                'epoch': epoch + 1,
                'model': get_model_name(config),
                'state_dict': model.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
                'iteration': args.iteration
            }, best_model, final_output_dir)

        # Every rank owns its own optimizer and scheduler, so the schedule
        # must advance on all ranks; stepping only on rank 0 would leave the
        # other ranks at the initial learning rate.
        lr_scheduler.step()

        try:
            dist.barrier()
        except Exception:
            if not mp.active_children():    # to kill zombie child process in case it exits abnormally
                return

    if rank == 0:
        final_model_state_file = os.path.join(final_output_dir,
                                              'final_state.pth.tar')
        logger.info('saving final model state to {}'.format(
            final_model_state_file))
        torch.save(model.state_dict(), final_model_state_file)

    # A SummaryWriter is created on every rank, so close it on every rank.
    writer_dict['writer'].close()

    print('Rank {} exit'.format(rank))


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
